#!/usr/bin/env escript

%%
%% %CopyrightBegin%
%% 
%% Copyright Ericsson AB 2018-2023. All Rights Reserved.
%% 
%% Licensed under the Apache License, Version 2.0 (the "License");
%% you may not use this file except in compliance with the License.
%% You may obtain a copy of the License at
%%
%%     http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing, software
%% distributed under the License is distributed on an "AS IS" BASIS,
%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
%% See the License for the specific language governing permissions and
%% limitations under the License.
%% 
%% %CopyrightEnd%
%%

%% ==========================================================================
%%
%% This is a simple wrapper escript on top of the socket ttest program(s).
%% The idea is to make it simple to run in a normal shell (bash).
%%
%% ==========================================================================

-define(SECS(I), timer:seconds(I)).

-define(CLIENT_MSG_1_MAX_OUTSTANDING, 100).
-define(CLIENT_MSG_2_MAX_OUTSTANDING,  10).
-define(CLIENT_MSG_3_MAX_OUTSTANDING,   1).

%% main(Args) -> ok
%%
%% escript entry point: turn the command line arguments into a state map
%% and execute the unit (server or client) it describes.
%% The result of the execution is intentionally discarded.
main(Args) ->
    _ = exec(process_args(Args)),
    ok.

%% usage(ErrorString) -> no_return()
%%
%% Report a command line error: print the error message, then the
%% usage text, and terminate the runtime system.
%% The exit status is non-zero so that a calling shell (script) can
%% detect that the invocation failed.
usage(ErrorString) when is_list(ErrorString) ->
    eprint(ErrorString),
    usage(),
    erlang:halt(1).

%% usage() -> ok
%%
%% Print the usage (help) text for this script on standard output.
%% The text is kept in sync with what the argument processing functions
%% actually accept (transports, client-only options and defaults).
usage() ->
    io:format("usage: ~s [options]"
              "~n"
              "~n   This erlang script is used to start the (e)socket ttest "
              "~n   units (server or client)."
              "~n"
              "~n   options: "
              "~n     --help                       Display this info and exit. "
              "~n     --server [server-options]    Start a server. "
              "~n                                  There are no mandatory server options."
              "~n     --client client-options      Start a client"
              "~n                                  Some client options are mandatory and"
              "~n                                  others optional."
              "~n     --domain <domain>            local | inet | inet6"
              "~n                                  Which domain to use."
              "~n                                  Only valid for server."
              "~n                                  Defaults to: inet"
              "~n     --async                      Asynchronous mode (Timeout = nowait)"
              "~n                                  This option is only valid for transport = sock."
              "~n                                  Also, it is only used when active =/= false."
              "~n     --active <active>            pos_integer() | boolean() | once."
              "~n                                  Valid for both client and server,"
              "~n                                  but pos_integer() (1..32767) is only"
              "~n                                  accepted for client."
              "~n                                  Defaults to: false"
              "~n     --transport <transport>      Which transport to use: gen|gs|sock[:plain|msg]"
              "~n                                     gen:  gen_tcp"
              "~n                                     gs:   gen_tcp with socket backend"
              "~n                                     sock: socket"
              "~n                                          plain: recv/send (default)"
              "~n                                          msg:   recvmsg/sendmsg"
              "~n                                  Defaults to: sock:plain"
              "~n     --scon <addr>:<port>|<path>  Server info."
              "~n                                  The address part is in the standard form:"
              "~n                                       \"a.b.c.d\"."
              "~n                                  <path> is used for Unix Domain sockets (local)."
              "~n                                  Only valid, and mandatory, for client."
              "~n     --msg-id <1|2|3>             Choose which message to use during the test."
              "~n                                  Basically: "
              "~n                                     1: small"
              "~n                                     2: medium"
              "~n                                     3: large"
              "~n                                  Only valid for client."
              "~n                                  Defaults to: 1"
              "~n     --max-outstanding <Num>      How many messages to send before waiting for"
              "~n                                  a reply."
              "~n                                  Valid only for client."
              "~n                                  Defaults to: "
              "~n                                     MsgID 1: 100"
              "~n                                     MsgID 2:  10"
              "~n                                     MsgID 3:   1"
              "~n     --runtime <Time>             Time of the test in seconds."
              "~n                                  Only valid for client."
              "~n                                  Defaults to: 60 (seconds)"
              "~n     --profile                    Run profiling. "
              "~n"
              "~n"
              "~n", 
              [scriptname()]),
    ok.

%% process_args(Args) -> State
%%
%% Top level command line dispatch. The first argument selects what to do:
%%   --help    Print the usage text and exit successfully.
%%   --server  The remaining arguments are server options.
%%   --client  The remaining arguments are client options.
%% Anything else is reported as an error via usage/1.
%% For --server and --client, the resulting state map is returned.
process_args(["--help"|_]) ->
    usage(),
    %% Stop here: returning the result of usage/0 would make the caller
    %% try to execute it as if it was a state map.
    erlang:halt(0);
process_args(["--server"|ServerArgs]) ->
    process_server_args(ServerArgs);
process_args(["--client"|ClientArgs]) ->
    process_client_args(ClientArgs);
process_args(Args) ->
    usage(f("Invalid Args: "
            "~n   ~p", [Args])).


%% process_server_args(Args) -> State
%%
%% Process the server options, starting from the server default state:
%% domain inet, synchronous, active false, transport socket (plain) and
%% no profiling.
process_server_args(Args) ->
    process_server_args(Args,
                        #{role      => server,
                          profile   => false,
                          transport => {sock, plain},
                          active    => false,
                          async     => false,
                          domain    => inet}).

%% process_server_args(Args, State) -> State
%%
%% Consume the server options one at a time, updating the state map
%% for each recognized option. When all arguments have been consumed
%% the state is returned. An unrecognized argument (which includes a
%% recognized option with an unsupported value) is reported via usage/1.

%% -- Done --
process_server_args([], State) ->
    State;

%% -- Flags (no value) --
process_server_args(["--profile" | Rest], State) ->
    process_server_args(Rest, State#{profile => true});
process_server_args(["--async" | Rest], State) ->
    process_server_args(Rest, State#{async => true});

%% -- Domain --
process_server_args(["--domain", "local" | Rest], State) ->
    process_server_args(Rest, State#{domain => local});
process_server_args(["--domain", "inet" | Rest], State) ->
    process_server_args(Rest, State#{domain => inet});
process_server_args(["--domain", "inet6" | Rest], State) ->
    process_server_args(Rest, State#{domain => inet6});

%% -- Active --
process_server_args(["--active", "false" | Rest], State) ->
    process_server_args(Rest, State#{active => false});
process_server_args(["--active", "once" | Rest], State) ->
    process_server_args(Rest, State#{active => once});
process_server_args(["--active", "true" | Rest], State) ->
    process_server_args(Rest, State#{active => true});

%% -- Transport --
process_server_args(["--transport", "gen" | Rest], State) ->
    process_server_args(Rest, State#{transport => gen});
process_server_args(["--transport", "gs" | Rest], State) ->
    process_server_args(Rest, State#{transport => gs});
process_server_args(["--transport", Sock | Rest], State)
  when (Sock =:= "sock") orelse (Sock =:= "sock:plain") ->
    %% "sock" is shorthand for "sock:plain"
    process_server_args(Rest, State#{transport => {sock, plain}});
process_server_args(["--transport", "sock:msg" | Rest], State) ->
    process_server_args(Rest, State#{transport => {sock, msg}});

%% -- Anything else is an error --
process_server_args([Unknown | _], _State) ->
    usage(f("Invalid Server arg: ~s", [Unknown])).


%% process_client_args(Args) -> State
%%
%% Process the client options, starting from the client default state.
process_client_args(Args) ->
    process_client_args(
      Args,
      #{role            => client,
        profile         => false,
        runtime         => ?SECS(60),
        %% Filled in, based on msg_id, when the option is not provided
        max_outstanding => undefined,
        msg_id          => 1,
        %% Must be provided ("addr:port" or a path string()),
        %% otherwise an error is reported when executing
        server          => undefined,
        transport       => {sock, plain},
        active          => false,
        async           => false}).

%% process_client_args(Args, State) -> State
%%
%% Consume the client options one at a time, updating the state map
%% for each recognized option. When all arguments have been consumed,
%% max_outstanding is given its msg_id dependent default (unless
%% provided) and the state is returned.
%% An invalid option value, or an unrecognized argument, is reported
%% via usage/1 (which terminates the script).
process_client_args([], State) ->
    process_client_args_ensure_max_outstanding(State);

process_client_args(["--profile"|Args], State) ->
    process_client_args(Args, State#{profile => true});

process_client_args(["--async"|Args], State) ->
    process_client_args(Args, State#{async => true});

process_client_args(["--active", Active|Args], State)
  when ((Active =:= "false") orelse
        (Active =:= "once") orelse
        (Active =:= "true")) ->
    process_client_args(Args, State#{active => list_to_atom(Active)});
process_client_args(["--active", Active|Args], State) ->
    %% Not one of the atoms, so it has to be a count: 1..32767
    N = process_client_arg_integer("Active", Active, 1, 32767),
    process_client_args(Args, State#{active => N});

process_client_args(["--transport", "gen" | Args], State) ->
    process_client_args(Args, State#{transport => gen});
process_client_args(["--transport", "gs" | Args], State) ->
    process_client_args(Args, State#{transport => gs});
process_client_args(["--transport", "sock" | Args], State) ->
    process_client_args(Args, State#{transport => {sock, plain}});
process_client_args(["--transport", "sock:plain" | Args], State) ->
    process_client_args(Args, State#{transport => {sock, plain}});
process_client_args(["--transport", "sock:msg" | Args], State) ->
    process_client_args(Args, State#{transport => {sock, msg}});

process_client_args(["--msg-id", MsgID|Args], State)
  when ((MsgID =:= "1") orelse
        (MsgID =:= "2") orelse
        (MsgID =:= "3")) ->
    process_client_args(Args, State#{msg_id => list_to_integer(MsgID)});

process_client_args(["--max-outstanding", Max|Args], State) ->
    I = process_client_arg_integer("Max Outstanding", Max, 1, infinity),
    process_client_args(Args, State#{max_outstanding => I});

process_client_args(["--scon", Server|Args], State) ->
    ServerInfo = process_client_arg_server(Server),
    process_client_args(Args, State#{server => ServerInfo});

process_client_args(["--runtime", T|Args], State) ->
    I = process_client_arg_integer("Run Time", T, 1, infinity),
    process_client_args(Args, State#{runtime => ?SECS(I)});

process_client_args([Arg|_], _State) ->
    usage(f("Invalid Client arg: ~s", [Arg])).


%% process_client_arg_integer(What, Str, Min, Max) -> integer()
%%
%% Convert the option value Str into an integer in the range Min..Max,
%% where Max is either an integer or the atom infinity (no upper limit).
%% What is the name of the value, used in the error message.
%% A value that is not an integer, or is out of range, is reported
%% via usage/1.
process_client_arg_integer(What, Str, Min, Max) ->
    try list_to_integer(Str) of
        I when (I >= Min) andalso
               ((Max =:= infinity) orelse (I =< Max)) ->
            I;
        _ ->
            usage(f("Invalid ~s: ~s", [What, Str]))
    catch
        error:badarg ->
            usage(f("Invalid ~s: ~s", [What, Str]))
    end.


%% process_client_arg_server(Server) -> {Addr, Port} | Path
%%
%% Parse the value of the --scon option. The string is split at the
%% last ":". If there is one, the value is taken to be <addr>:<port>,
%% and both parts are validated (port: 1..65535). If there is none,
%% the (non-empty) string is taken to be a path.
process_client_arg_server(Server) ->
    case string:split(Server, ":", trailing) of
        [AddrStr, PortStr] ->
            Addr = case inet:parse_address(AddrStr) of
                       {ok, A} ->
                           A;
                       {error, _} ->
                           usage(f("Invalid Server Address: ~s", [AddrStr]))
                   end,
            Port = process_client_arg_integer("Server Port", PortStr,
                                              1, 65535),
            {Addr, Port};
        [Path] when (Path =/= []) ->
            Path;
        _ ->
            usage(f("Invalid Server: ~s", [Server]))
    end.


%% process_client_args_ensure_max_outstanding(State) -> State
%%
%% Final step of the client option processing.
%% If max_outstanding was not provided (is undefined), it is set to the
%% default for the selected msg_id. If it was provided, the state is
%% returned as is, given that msg_id is one of 1, 2 or 3 and
%% max_outstanding is a positive integer. Any other combination is
%% reported via usage/1.
process_client_args_ensure_max_outstanding(
  #{msg_id := MsgID, max_outstanding := undefined} = State)
  when (MsgID =:= 1) orelse (MsgID =:= 2) orelse (MsgID =:= 3) ->
    Default =
        case MsgID of
            1 -> ?CLIENT_MSG_1_MAX_OUTSTANDING;
            2 -> ?CLIENT_MSG_2_MAX_OUTSTANDING;
            3 -> ?CLIENT_MSG_3_MAX_OUTSTANDING
        end,
    State#{max_outstanding => Default};
process_client_args_ensure_max_outstanding(
  #{msg_id := MsgID, max_outstanding := Provided} = State)
  when is_integer(Provided) andalso (Provided > 0) andalso
       ((MsgID =:= 1) orelse (MsgID =:= 2) orelse (MsgID =:= 3)) ->
    State;
process_client_args_ensure_max_outstanding(
  #{msg_id := BadMsgID, max_outstanding := BadMax}) ->
    usage(f("Invalid Msg ID (~w) and Max Outstanding (~w)", 
            [BadMsgID, BadMax])).



%% ==========================================================================

%% exec(State) -> term()
%%
%% Start the unit (server or client) described by the state map and
%% wait for the started process to terminate.
%% Each clause selects the start function of the module implementing
%% the chosen role and transport (gen, gs or sock). The gen and gs
%% servers are only started for the domains inet and inet6, and the
%% gen and gs clients only for a server given as {Addr, Port}.
%% A client without a server, and any other combination that is not
%% handled, is reported via usage/1.
exec(#{role      := server,
       domain    := Domain,
       active    := Active,
       transport := gen,
       profile   := Profile})
  when (Domain =:= inet) orelse (Domain =:= inet6) ->
    exec_run(server, Profile, "server-gen-~w", [Active],
             fun() ->
                     socket_test_ttest_tcp_server_gen:start(Domain, Active)
             end);
exec(#{role      := server,
       domain    := Domain,
       active    := Active,
       transport := gs,
       profile   := Profile})
  when (Domain =:= inet) orelse (Domain =:= inet6) ->
    exec_run(server, Profile, "server-gs-~w", [Active],
             fun() ->
                     socket_test_ttest_tcp_server_gs:start(Domain, Active)
             end);
exec(#{role      := server,
       domain    := Domain,
       async     := Async,
       active    := Active,
       transport := {sock, Method},
       profile   := Profile}) ->
    exec_run(server, Profile, "server-sock-~w-~w", [Method, Active],
             fun() ->
                     socket_test_ttest_tcp_server_socket:start(Method, Domain,
                                                               Async, Active)
             end);

exec(#{role   := client,
       server := undefined}) ->
    usage("Mandatory option 'server' not provided");
exec(#{role            := client,
       server          := {_Addr, _Port} = ServerInfo,
       active          := Active,
       transport       := gen,
       msg_id          := MsgID,
       max_outstanding := MaxOutstanding,
       runtime         := RunTime,
       profile         := Profile}) ->
    exec_run(client, Profile, "client-gen-~w", [Active],
             fun() ->
                     socket_test_ttest_tcp_client_gen:start(true,
                                                            ServerInfo,
                                                            Active,
                                                            MsgID,
                                                            MaxOutstanding,
                                                            RunTime)
             end);
exec(#{role            := client,
       server          := {_Addr, _Port} = ServerInfo,
       active          := Active,
       transport       := gs,
       msg_id          := MsgID,
       max_outstanding := MaxOutstanding,
       runtime         := RunTime,
       profile         := Profile}) ->
    exec_run(client, Profile, "client-gs-~w", [Active],
             fun() ->
                     socket_test_ttest_tcp_client_gs:start(true,
                                                           ServerInfo,
                                                           Active,
                                                           MsgID,
                                                           MaxOutstanding,
                                                           RunTime)
             end);
exec(#{role            := client,
       server          := ServerInfo,
       async           := Async,
       active          := Active,
       transport       := {sock, Method},
       msg_id          := MsgID,
       max_outstanding := MaxOutstanding,
       runtime         := RunTime,
       profile         := Profile}) ->
    exec_run(client, Profile, "client-sock-~w-~w", [Method, Active],
             fun() ->
                     socket_test_ttest_tcp_client_socket:start(true,
                                                               Async,
                                                               Active,
                                                               Method,
                                                               ServerInfo,
                                                               MsgID,
                                                               MaxOutstanding,
                                                               RunTime)
             end);
exec(_) ->
    usage("Unexpected option combo"),
    ok.


%% exec_run(Role, Profile, NameFormat, NameArgs, Start) -> term()
%%
%% Run the unit started by the fun Start, either directly
%% (Profile = false) or through socket_test_profile:profile/2
%% (Profile = true). The profile name is only built when it is needed.
exec_run(Role, Profile, NameFormat, NameArgs, Start) ->
    Run = fun() -> exec_await(Role, Start) end,
    case Profile of
        true ->
            socket_test_profile:profile(f(NameFormat, NameArgs), Run);
        false ->
            Run()
    end.


%% exec_await(Role, Start) -> Info | error
%%
%% Call the start fun and, if that succeeds, monitor the started process
%% and block until it terminates; the exit reason (Info) is returned.
%% A failed start is reported, naming the role that failed to start,
%% and the atom error is returned.
exec_await(Role, Start) ->
    case Start() of
        {ok, Started} ->
            Pid  = exec_pid(Role, Started),
            MRef = erlang:monitor(process, Pid),
            receive
                {'DOWN', MRef, process, Pid, Info} ->
                    Info
            end;
        {error, Reason} ->
            eprint(f("Failed starting ~w: "
                     "~n   ~p", [Role, Reason])),
            error
    end.


%% exec_pid(Role, Started) -> Pid
%%
%% Extract the process to wait for from a successful start result.
%% The server start result is matched as {Pid, _}, the client start
%% result is used as it is.
exec_pid(server, {Pid, _}) ->
    Pid;
exec_pid(client, Pid) ->
    Pid.


%% ==========================================================================
    
%% f(Format, Args) -> term()
%%
%% Shorthand for socket_test_ttest_lib:format/2.
f(Format, Args) ->
    socket_test_ttest_lib:format(Format, Args).

%% eprint(ErrorString) -> ok
%%
%% Print an error message, prefixed with "<ERROR> " and followed by an
%% empty line.
%% The message is passed as data to the format string (not as part of
%% it), since it may contain text provided by the user (e.g. an invalid
%% option value) and a '~' in that text must not be interpreted as the
%% start of a control sequence.
eprint(ErrorString) when is_list(ErrorString) ->
    print("<ERROR> ~ts~n", [ErrorString]).

%% print(Format, Args) -> ok
%%
%% Formatted output, as io:format/2, with a newline added at the end.
print(Format, Args) ->
    Line = Format ++ "~n",
    io:format(Line, Args).

%% scriptname() -> string()
%%
%% The name of this script, without the directory part.
scriptname() ->
    filename:basename(escript:script_name()).

